__author__ = 'Amit'

import bs4
import requests
import re
import sys
from time import sleep
import zipfile

#exampleSeq = 'GGCUAAUGGUGACCUCAUUGUUUUCAGAUCUCGGUGGCAGGUCCCUCGUCCAUUUACUCCAGUGGUAAUCAA'


######################## Nupack ########################

class NupackSession():
    """Client session for the NUPACK web service (www.nupack.org).

    Submits RNA partition-function jobs over HTTP, polls the server until
    the results are ready, downloads the zipped result archive, and parses
    the result files (base-pairing dot-plot data and equilibrium
    concentrations).
    """

    # After a job is submitted the server redirects to a URL of the form
    # .../partition/<jobID>?token=<jobToken>; these patterns pull both out.
    patJobID = re.compile(r'/(\d+)', re.M)
    patJobToken = re.compile(r'token=(\w+)', re.M)

    def __init__(self, destZipPath, destExtractFolderPath):
        """
        destZipPath -- folder in which downloaded result zips are saved.
        destExtractFolderPath -- folder into which result zips are extracted.
        """
        super().__init__()
        self.destZipPath = destZipPath
        self.destExtractFolderPath = destExtractFolderPath

    def analyzeUnpairedBasesProbability(self, sequence):
        """Run a single-strand job for *sequence* and return the parsed
        per-base unpaired-probability table (see parseDatFile)."""
        self.sequence = sequence
        (self.jobID, self.jobToken) = self.startJobOneSpecie(sequence)
        self.resultsZipFileDest = self.destZipPath + '/' + self.jobID + '.zip'
        self.resultsFolder = self.destExtractFolderPath + '/' + self.jobID
        self.downloadResultsWhenReady(self.jobID, self.jobToken, self.resultsZipFileDest, self.resultsFolder)
        return self.getUnpairedBasesProbabilityData()

    def startJobOneSpecie(self, sequence):
        """Submit a one-species RNA partition job; return (jobID, jobToken)."""
        serverURL = "http://www.nupack.org/partition/new"
        self.temperature = '37.0'
        payload = {
            'partition_job[nucleic_acid_type]':             'RNA',
            'partition_job[temperature]':                   self.temperature,
            'partition_job[min_melt_temperature]':          '',
            'partition_job[is_melt]':                       '0',
            'partition_job[melt_temperature_increment]':    '',
            'partition_job[max_melt_temperature]':          '',
            'partition_job[num_sequences]':                 '1',
            'partition_job[max_complex_size]':              '1',
            'partition_sequence[0][name]':                  'strand1',
            'partition_sequence[0][contents]':              sequence,
            'partition_sequence[0][concentration]':         '',
            'partition_sequence[0][scale]':                 '-6',
            'partition_job[rna_parameter_file]':            'rna1995',
            'partition_job[dna_parameter_file]':            'dna1998',
            'partition_job[dangle_level]':                  '1',
            'partition_job[pseudoknots]':                   '0',
            'partition_job[na_salt]':                       '1.0',
            'partition_job[mg_salt]':                       '0.0',
            'partition_job[dotplot_target]':                '1',
            'partition_job[predefined_complexes]':          '',
            'partition_job[filter_min_fraction_of_max]':    '0.001',
            'partition_job[filter_max_number]':             '10',
            'partition_job[email_address]':                 '',
            'commit': 'Analyze'
        }
        r = requests.post(serverURL, payload)
        # The job id and token are embedded in the redirect URL.
        jobID = re.search(NupackSession.patJobID, r.url).group(1)
        jobToken = re.search(NupackSession.patJobToken, r.url).group(1)
        return (jobID, jobToken)

    def getUnpairedBasesProbabilityData(self):
        """Read and parse the <jobID>_0_0_dp.dat dot-plot file produced by
        the last single-species job."""
        datFileName = self.jobID + "_0_0_dp.dat"
        datFilePath = self.resultsFolder + '/' + self.jobID + '/' + self.temperature + '/' + datFileName
        with open(datFilePath) as datFile:
            rawData = datFile.read()
        return NupackSession.parseDatFile(rawData)

    @staticmethod
    def chooseLowProbableBasePairingSequence(unpairedBasesProbability, seqLength):
        """Find the length-*seqLength* window with the highest total
        unpaired probability, i.e. the region least likely to be base-paired.

        unpairedBasesProbability -- list of [basePosition, probability]
            rows, as returned by parseDatFile.
        Returns (start, end): inclusive base positions, where start is the
        position value (column 0) of the chosen window's first row.
        """
        dataLen = len(unpairedBasesProbability)
        # Score of the first window: indices 0 .. seqLength-1 inclusive.
        score = sum(row[1] for row in unpairedBasesProbability[:seqLength])
        maxScore = (0, score)
        # Slide the window one base at a time, updating the sum
        # incrementally: add the element entering the window and drop the
        # element leaving it, keeping the best (start, score) seen so far.
        for start in range(1, dataLen - seqLength + 1):
            score += unpairedBasesProbability[start + seqLength - 1][1]
            score -= unpairedBasesProbability[start - 1][1]
            if score > maxScore[1]:
                maxScore = (start, score)
        chosenSeqStart = int(unpairedBasesProbability[maxScore[0]][0])
        return (chosenSeqStart, chosenSeqStart + seqLength - 1)

    def analyzeMultipleSpeciesConcentrations(self, sequences, titles, concentrations):
        """Run a multi-species job and return the parsed equilibrium
        concentration table (see parseConcentrationFile)."""
        self.sequences = sequences
        self.titles = titles
        self.concentrations = concentrations
        (self.jobID, self.jobToken) = self.startJobMultipleSpecies(sequences, titles, concentrations)
        self.resultsZipFileDest = self.destZipPath + '/' + self.jobID + '.zip'
        # NOTE: unlike the single-species flow, results are extracted
        # directly into the destination folder (no per-job subfolder).
        self.resultsFolder = self.destExtractFolderPath
        self.downloadResultsWhenReady(self.jobID, self.jobToken, self.resultsZipFileDest, self.resultsFolder)
        return self.parseConcentrationFile(self.getConcentrationsFile())

    def startJobMultipleSpecies(self, sequences, titles, concentrations):
        """Submit a multi-species RNA partition job; return (jobID, jobToken)."""
        numOfSpecies = len(sequences)
        serverURL = "http://www.nupack.org/partition/new"
        self.temperature = '37.0'
        payload = {
            'partition_job[nucleic_acid_type]':             'RNA',
            'partition_job[temperature]':                   self.temperature,
            'partition_job[min_melt_temperature]':          '',
            'partition_job[is_melt]':                       '0',
            'partition_job[melt_temperature_increment]':    '',
            'partition_job[max_melt_temperature]':          '',
            'partition_job[num_sequences]':                 str(numOfSpecies),
            'partition_job[max_complex_size]':              str(numOfSpecies),
            }
        # One form entry block per species.
        for i in range(numOfSpecies):
            payload['partition_sequence[' + str(i) + '][name]'] = titles[i]
            payload['partition_sequence[' + str(i) + '][contents]'] = sequences[i]
            payload['partition_sequence[' + str(i) + '][concentration]'] = concentrations[i]
            payload['partition_sequence[' + str(i) + '][scale]'] = '-6'
        payload.update({
            'partition_job[rna_parameter_file]':            'rna1995',
            'partition_job[dna_parameter_file]':            'dna1998',
            'partition_job[dangle_level]':                  '1',
            'partition_job[pseudoknots]':                   '0',
            'partition_job[na_salt]':                       '1.0',
            'partition_job[mg_salt]':                       '0.0',
            'partition_job[dotplot_target]':                '1',
            'partition_job[predefined_complexes]':          '',
            'partition_job[filter_min_fraction_of_max]':    '0.001',
            'partition_job[filter_max_number]':             '10',
            'partition_job[email_address]':                 '',
            'commit': 'Analyze'
        })
        r = requests.post(serverURL, payload)
        # The job id and token are embedded in the redirect URL.
        jobID = re.search(NupackSession.patJobID, r.url).group(1)
        jobToken = re.search(NupackSession.patJobToken, r.url).group(1)
        return (jobID, jobToken)

    def getConcentrationsFile(self):
        """Return the raw contents of the <jobID>.eq concentrations file,
        re-downloading the results until the file appears on disk.

        Retries indefinitely (as the original flow did); each retry waits
        25 s and re-runs the download/extract step.
        """
        eqFileName = self.jobID + ".eq"
        eqFilePath = self.resultsFolder + '/' + self.jobID + '/' + self.temperature + '/' + eqFileName
        while True:
            try:
                with open(eqFilePath) as eqFile:
                    return eqFile.read()
            except OSError:
                # File not extracted yet -- wait and retry the download.
                sleep(25)
                self.downloadResultsWhenReady(self.jobID, self.jobToken, self.resultsZipFileDest, self.resultsFolder)

    def parseConcentrationFile(self, rawData):
        """Parse a NUPACK .eq file.

        Rows are whitespace-separated; '%' lines are comments. For n
        species, columns 2..2+n-1 hold the copy number of each strand in
        the complex, and column 3+n holds the equilibrium concentration.
        Returns a list of [countsTuple, complexName, concentration] rows,
        where complexName is e.g. 'A-A-B' built from self.titles.
        """
        patSeparator = re.compile(r'\s+')
        lines = rawData.split('\n')
        lines = [re.split(patSeparator, line) for line in lines if (len(line) > 0 and line[0] != '%')]

        numSpecies = len(self.sequences)
        concentrationTable = []
        for fields in lines:
            # Repeat each title by its copy number, then join with '-'.
            parts = []
            for j in range(numSpecies):
                parts.extend([self.titles[j]] * int(fields[2 + j]))
            strComplex = '-'.join(parts)
            tupleComplex = tuple(fields[2:2 + numSpecies])
            concentrationTable += [[tupleComplex, strComplex, float(fields[3 + numSpecies])]]
        return concentrationTable

    def downloadResultsWhenReady(self, jobID, jobToken, destZipPath, destExtractFolderPath):
        """Wait until the job's download URL answers, then fetch the
        zipped results, save them to *destZipPath* and extract them into
        *destExtractFolderPath*."""
        jobFilesURL = \
            'http://www.nupack.org/partition/download_tar/' + jobID + \
            '?token=' + jobToken
        if returnWhenExistsURL(jobFilesURL, 1):
            sleep(5)    # give the server time to finish zipping all files
            filesBinary = downloadFile(jobFilesURL)
            saveFileTo(filesBinary, destZipPath)
            self.extractResultsToDestination(destExtractFolderPath)

    @staticmethod
    def parseDatFile(rawData):
        """Parse a NUPACK dot-plot .dat file.

        Rows are whitespace-separated; '%' lines are comments. A second
        column of '-1' marks an unpaired entry. Returns a list of
        [basePosition, probability] pairs for the unpaired entries only.
        """
        patSeparator = re.compile(r'\s+')
        lines = rawData.split('\n')
        lines = [re.split(patSeparator, line) for line in lines if (len(line) > 0 and line[0] != '%')]

        unpairedBasesProbability = []
        for line in lines:
            if (line[1] == '-1'):
                unpairedBasesProbability.append([line[0], float(line[2])])

        return unpairedBasesProbability

    # Extract the downloaded results zip into the destination folder.
    def extractResultsToDestination(self, destFolder):
        with zipfile.ZipFile(self.resultsZipFileDest) as zf:
            zf.extractall(destFolder)


######################## InvivoGen Wizard ########################

class InvivoGenWizardSession():
    """Scrapes candidate siRNA target sequences for a gene from the
    InvivoGen siRNA Wizard (www.sirnawizard.com)."""

    def __init__(self, targetGene, seqLength):
        """Query the wizard for *targetGene* with target length
        *seqLength* and cache the resulting target list."""
        super().__init__()
        self.targetList = InvivoGenWizardSession.constructTargetSeqList(targetGene, seqLength)

    def getTargetSeqList(self):
        """Return the cached list of candidate target sequences."""
        return self.targetList

    @staticmethod
    def constructTargetSeqList(targetGene, seqLength):
        """POST the gene sequence to the siRNA wizard and scrape the
        result page for candidate target sequences (as a list of str)."""
        targetServerURL = "http://www.sirnawizard.com/siRNA.php"
        payload = {
            'sequence': targetGene,
            'taille_motif': str(seqLength),
            'database': './Base_ARNm_human',
            'mirna_database': './FAD_miRNA_humain'
        }
        r = requests.post(targetServerURL, payload)
        # Name the parser explicitly: without it bs4 picks whichever
        # parser happens to be installed, which can change the parse tree
        # (and emits a GuessedAtParserWarning in modern bs4).
        soup = bs4.BeautifulSoup(r.text, 'html.parser')
        # The second <form> on the page holds the result sequences,
        # rendered as <font size="3" color="#003399"> elements.
        resultsForm = soup.find_all('form')[1]
        hits = resultsForm.find_all('font', {'size': '3', 'color': '#003399'})
        return [hit.contents[0].strip() for hit in hits]


######################## Auxiliary functions ########################

# Download and return the file (binary) from URL
def downloadFile(url):
    """Stream *url* and return the response body as a bytearray.

    A failed request (non-OK status) yields an empty buffer rather than
    raising.
    """
    response = requests.get(url, stream=True)
    data = bytearray()
    if not response.ok:
        return data
    for chunk in response.iter_content(1024):
        data += chunk
    return data

# Saves a given binary file to pathname
# Saves a given binary file to pathname
def saveFileTo(file, outputPathName):
    """Write the binary content *file* to *outputPathName*.

    The 'with' block closes the handle on exit, so no explicit close()
    is needed (the original redundant call is removed).
    """
    with open(outputPathName, "wb") as f:
        f.write(file)

# Returns True once the URL exists (i.e. responds with HTTP 200)
def returnWhenExistsURL(refreshURL, retryLapTime):
    """Poll *refreshURL* every *retryLapTime* seconds until it answers
    HTTP 200, then return True.

    Gives up after 45 retries and terminates the process via sys.exit()
    (preserving the original fail-hard behaviour). Unlike before, a poll
    that succeeds on the very last retry now returns True instead of
    exiting: the timeout check looks at the final status code rather
    than only the retry counter.
    """
    timeout = 45
    results = requests.get(refreshURL)
    timeCounter = 0
    while (results.status_code != 200 and timeCounter != timeout):
        sleep(retryLapTime)
        timeCounter += 1
        results = requests.get(refreshURL)
    if (results.status_code != 200):
        # Still failing after the retry budget: abort the whole run.
        # sys.exit() is the guaranteed API; builtin exit() is a site-module
        # convenience intended only for the interactive interpreter.
        sys.exit()
    return True
